from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd


# Start each run from a clean slate by deleting the directory below.
# NOTE(review): the path is hard-coded under /root and its last component
# matches the Pipeline name used in this file; presumably this is where
# distilabel keeps cached results for that pipeline — confirm.
pipeline_cache = "/root/.cache/distilabel/pipelines/distill-qwen-32b-r1-exploretom"
if os.path.exists(pipeline_cache):
    # Guarded so a missing directory (e.g. on the first run) is not an error.
    shutil.rmtree(pipeline_cache)


# Jinja-style template handed to the TextGeneration task: the
# ``{{ story_structure }}`` and ``{{question}}`` placeholders are filled from
# the dataset columns of the same names. The doubled backslash produces a
# literal ``\boxed{}`` in the rendered prompt.
prompt_template = (
    "You will be given a story and a question. Please reason step by step "
    "to answer this question, and put your final answer within \\boxed{}:\n"
    "\n"
    "Story: {{ story_structure }}\n"
    "Question: {{question}}\n"
)

# Read the ExploreToM sample CSV; the split expression keeps only the first
# 1000 rows of the "train" split.
dataset = load_dataset(
    "csv",
    data_files=".../ToM_data/ExploreToM/ExploreToM-data-sample.csv",
    split="train[:1000]",
)




def add_combined_column(dataset):
    """Return ``dataset.map(...)`` with an ``entire_instruction`` field added.

    For every example, the ``story_structure`` and ``question`` fields are
    joined into one string of the form ``"Story: <story> Question: <question>"``
    and stored under the ``entire_instruction`` key.

    Args:
        dataset: Any object exposing ``map(fn)``, where ``fn`` receives one
            example as a mapping containing ``story_structure`` and
            ``question``.

    Returns:
        Whatever ``dataset.map`` returns for the per-example transformation.
    """
    def _attach_instruction(row):
        # Read both source fields first, then write the merged text back
        # onto the same example object that was passed in.
        story, question = row['story_structure'], row['question']
        row["entire_instruction"] = f"Story: {story} Question: {question}"
        return row

    return dataset.map(_attach_instruction)

# Add the combined "entire_instruction" column, then echo the dataset summary
# followed by its first row as a quick sanity check.
dataset = add_combined_column(dataset)
print(dataset)
first_example = dataset[0]
print(first_example)



# Model location; passed below as both the model and the tokenizer.
model_id = ".../distill32B"

with Pipeline(
    name="distill-qwen-32b-r1-exploretom",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:
    # Options forwarded to vLLM through ``extra_kwargs``.
    engine_options = {
        "tensor_parallel_size": 1,
        "max_model_len": 16384,
    }
    # Options forwarded to the LLM through ``generation_kwargs``.
    # NOTE(review): max_new_tokens equals max_model_len, so a prompt plus a
    # full-length completion cannot both fit in the context — confirm this is
    # intended.
    sampling_options = {
        "temperature": 0.7,
        "max_new_tokens": 16384,
    }

    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs=engine_options,
        generation_kwargs=sampling_options,
    )

    # Generation step: fills ``prompt_template`` from the two listed columns
    # and requests two generations per input, in input batches of four.
    text_generation = TextGeneration(
        llm=llm,
        template=prompt_template,
        num_generations=2,
        input_batch_size=4,
        columns=["story_structure", "question"],
    )

    # SFT formatting step: its "instruction" input is mapped to the
    # "entire_instruction" column.
    format_sft = FormatTextGenerationSFT(
        input_mappings={"instruction": "entire_instruction"},
    )

    # Wire the steps: generation output feeds the SFT formatter.
    text_generation.connect(format_sft)
    




if __name__ == "__main__":
    # Single definition of the output location so the save and the
    # round-trip load below cannot drift apart.
    output_dir = ".../SFTData/ExploreToM_May_test"

    # Run the pipeline over the prepared dataset and show the result plus one
    # sample row from the default/train split.
    distiset = pipeline.run(dataset=dataset)
    print(distiset)
    print(distiset['default']['train'][0])

    distiset.save_to_disk(output_dir)

    # load_from_disk returns the loaded Distiset instead of updating the
    # object it is called on. Bind the result so the final print shows what
    # was actually read back from disk; discarding it would just re-print the
    # in-memory object and verify nothing.
    distiset = distiset.load_from_disk(output_dir)
    print(distiset)